package cn.com.openportal.ffw.common.utils;

import cn.com.openportal.ffw.common.annotation.DataFilterNative;
import cn.com.openportal.ffw.common.exception.RRException;
import cn.com.openportal.ffw.modules.sys.entity.SysUserEntity;
import cn.com.openportal.ffw.modules.sys.service.SysDeptService;
import org.apache.commons.lang.StringUtils;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.lang.reflect.Method;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Builds and enforces department/user based data-permission filters for "native" controllers
 * whose handler methods carry a {@link DataFilterNative} annotation.
 *
 * <p>All helpers are driven from an AOP {@link JoinPoint}. The intercepted method's
 * {@link DataFilterNative} annotation selects the columns to filter on. The target controller's
 * class name is mapped to a service bean name by {@link #getServiceName(JoinPoint)}, and that
 * bean is then called reflectively.
 *
 * @author LeeSon QQ & WX:25901875
 * @version V1.0
 * @Package cn.com.openportal.ffw.common.utils
 * @create 2020-01-17 21:00
 * @Copyright © 2019 LeeSon QQ & WX:25901875
 */
@Component
public class SQLFilterNativeUtils {
    /** Service method that {@link #doSQLtoDeleteAll} redirects an unrestricted delete-all to. */
    private static final String DELETE_BY_QUERY_METHOD = "deleteByQueryNative";

    @Autowired
    private SysDeptService sysDeptService;

    /**
     * Builds the SQL fragment that restricts a query to the rows the user may access.
     *
     * <p>The result looks like {@code " (alias.dept_col in(1,2) and alias.user_col=42)"}.
     * The department condition uses the user's department (plus sub-departments when
     * {@link DataFilterNative#subDept()} is set). The user condition is only added when
     * {@link DataFilterNative#user()} is set.
     *
     * @param user  the current user
     * @param point join point of the intercepted method annotated with {@link DataFilterNative}
     * @return the filter fragment (with a leading space), or {@code null} if no condition applies
     */
    public String getSQLFilter(SysUserEntity user, JoinPoint point) {
        DataFilterNative dataFilter = getDataFilterAnnotation(point);
        // Qualify column names with the table alias, if one is configured.
        String tableAlias = dataFilter.tableAlias();
        if (StringUtils.isNotBlank(tableAlias)) {
            tableAlias += ".";
        }
        Set<Long> deptIdList = getDeptIdList(user, dataFilter);

        StringBuilder conditions = new StringBuilder();
        if (!deptIdList.isEmpty()) {
            // The values are Long IDs; column names come from the annotation, not from user input.
            conditions.append(tableAlias).append(dataFilter.deptId())
                    .append(" in(").append(StringUtils.join(deptIdList, ",")).append(")");
        }
        // Restrict to rows belonging to the current user only.
        if (dataFilter.user()) {
            if (conditions.length() > 0) {
                conditions.append(" and ");
            }
            conditions.append(tableAlias).append(dataFilter.userId()).append("=").append(user.getUserId());
        }
        if (conditions.length() == 0) {
            return null;
        }
        return " (" + conditions + ")";
    }

    /**
     * Replaces an unrestricted "delete all" with a filtered delete through the service's
     * {@code deleteByQueryNative(query)} method.
     *
     * <p>The query object is created with the single parameter type's no-arg constructor, and
     * the data filter is applied to it. After the filtered delete runs, an {@link RRException}
     * is thrown on purpose so that the intercepted delete-all is aborted.
     *
     * @param user  the current user
     * @param point join point of the intercepted delete-all method
     * @throws RRException always, once the filtered delete has been performed
     * @throws Exception   on reflection failures or errors thrown by the service
     */
    public void doSQLtoDeleteAll(SysUserEntity user, JoinPoint point) throws Exception {
        Object service = SpringContextUtils.getBean(getServiceName(point));
        // getMethods() also finds a public deleteByQueryNative inherited from a base service, and
        // it matches the public-only lookup used by doMethodByParam. Bridge methods are skipped
        // because their erased parameter type (e.g. Object) cannot carry the SQL filter.
        for (Method method : service.getClass().getMethods()) {
            if (method.isBridge()
                    || !DELETE_BY_QUERY_METHOD.equals(method.getName())
                    || method.getParameterCount() != 1) {
                continue;
            }
            Object query = method.getParameterTypes()[0].getDeclaredConstructor().newInstance();
            doDataFilter(user, point, query);
            doMethodByParam(service, DELETE_BY_QUERY_METHOD, query);
            // Abort the intercepted delete-all; the filtered delete above replaces it.
            throw new RRException(getMethodName(point) + " Transition [DeleteAll] to [DeleteQuery]");
        }
        // NOTE(review): if no deleteByQueryNative(query) exists, this returns normally and the
        // intercepted (unfiltered) delete-all proceeds. Confirm that this is intended.
    }

    /**
     * Stores the data-permission filter on a query object by calling its
     * {@code Constant.SQL_FILTER}(String) setter.
     *
     * @param user  the current user
     * @param point join point of the intercepted method
     * @param obj   query object exposing a public one-String-argument method named
     *              {@code Constant.SQL_FILTER}
     * @throws Exception on reflection failures
     */
    public void doDataFilter(SysUserEntity user, JoinPoint point, Object obj) throws Exception {
        // Looked up with String.class explicitly because getSQLFilter may return null,
        // so the argument's runtime class cannot be used for the lookup.
        Method method = obj.getClass().getMethod(Constant.SQL_FILTER, String.class);
        method.invoke(obj, getSQLFilter(user, point));
    }

    /**
     * Checks that the user may operate on every record whose ID is in {@code obj}.
     *
     * @param obj a {@code Long[]} of record IDs
     * @throws RRException if any record is outside the user's data scope
     */
    public void isCanSQLtoIDS(SysUserEntity user, JoinPoint point, Object obj) throws Exception {
        Long[] ids = (Long[]) obj;
        isCanSQL(user, point, obj, ids);
    }

    /**
     * Checks that the user may operate on the single record identified by {@code obj}.
     *
     * @param obj a {@code Long} record ID
     * @throws RRException if the record is outside the user's data scope
     */
    public void isCanSQLtoID(SysUserEntity user, JoinPoint point, Object obj) throws Exception {
        Long id = (Long) obj;
        Long[] ids = new Long[]{id};
        isCanSQL(user, point, obj, ids);
    }

    /**
     * Checks that the user may update the record whose ID is returned by {@code obj.getId()}.
     *
     * @param obj an entity exposing a public {@code Long getId()}
     * @throws RRException if the record is outside the user's data scope
     */
    public void isCanSQLtoUpdate(SysUserEntity user, JoinPoint point, Object obj) throws Exception {
        Long id = (Long) doMethodByNoParam(obj, "getId");
        Long[] ids = new Long[]{id};
        isCanSQL(user, point, obj, ids);
    }

    /**
     * Loads each record through the service's {@code getByKeyNative(Long)} and checks that the
     * user is allowed to access it.
     *
     * <p>A record with a {@code null} department ID is not restricted. Otherwise its department
     * must be within the user's department scope. When {@link DataFilterNative#user()} is set,
     * its non-null creator must also be the current user.
     *
     * @param user  the current user
     * @param point join point of the intercepted method
     * @param obj   the original argument (not read here; kept for interface compatibility)
     * @param ids   IDs of the records to check
     * @throws RRException if any record is outside the user's data scope
     * @throws Exception   on reflection failures or errors thrown by the service
     */
    public void isCanSQL(SysUserEntity user, JoinPoint point, Object obj, Long[] ids) throws Exception {
        DataFilterNative dataFilter = getDataFilterAnnotation(point);
        Set<Long> deptIdList = getDeptIdList(user, dataFilter);
        // The service bean is the same for every ID, so resolve it once.
        Object service = SpringContextUtils.getBean(getServiceName(point));
        for (Long id : ids) {
            Object entity = doMethodByParam(service, "getByKeyNative", id);
            Long deptId = (Long) doMethodByNoParam(entity, "getDeptId");
            // Records without a department are not restricted.
            if (deptId == null) {
                continue;
            }
            if (!deptIdList.contains(deptId)) {
                throw new RRException("Cant Do " + getMethodName(point));
            }
            if (dataFilter.user()) {
                Long createUserId = (Long) doMethodByNoParam(entity, "getCreateUserId");
                // Compare the creator with the current user's ID (not the department ID),
                // matching the userId condition built by getSQLFilter.
                if (createUserId != null && !createUserId.equals(user.getUserId())) {
                    throw new RRException("Cant Do " + getMethodName(point));
                }
            }
        }
    }

    /**
     * Collects the department IDs visible to the user: the user's own department, plus its
     * sub-departments when {@link DataFilterNative#subDept()} is set.
     *
     * <p>NOTE(review): if {@code user.getDeptId()} is null, a null element is added to the set.
     */
    private Set<Long> getDeptIdList(SysUserEntity user, DataFilterNative dataFilter) {
        Set<Long> deptIdList = new HashSet<>();
        // The user's own department.
        deptIdList.add(user.getDeptId());
        // The user's sub-departments.
        if (dataFilter.subDept()) {
            List<Long> subDeptIdList = sysDeptService.getSubDeptIdList(user.getDeptId());
            deptIdList.addAll(subDeptIdList);
        }
        return deptIdList;
    }

    /**
     * Derives the service bean name from the target controller's class name.
     * For example, {@code FooControllerNative} becomes {@code fooServiceImplNative}.
     */
    public static String getServiceName(JoinPoint point) {
        String simpleName = getName(point.getTarget().getClass().getName());
        // Literal replacement; no regex is needed.
        return StringUtils.uncapitalize(simpleName.replace("ControllerNative", "")) + "ServiceImplNative";
    }

    /**
     * Returns the part of {@code name} after the last '.', or the whole string if it has no '.'.
     */
    public static String getName(String name) {
        int pos = name.lastIndexOf(".");
        return name.substring(pos + 1);
    }

    /**
     * Invokes the public one-argument method {@code methodName} on {@code obj}.
     *
     * <p>The method is looked up by the exact runtime class of {@code p}, so {@code p} must be
     * non-null and must match the declared parameter type exactly.
     *
     * @return the method's return value
     * @throws Exception on reflection failures or errors thrown by the invoked method
     */
    public static Object doMethodByParam(Object obj, String methodName, Object p) throws Exception {
        Method method = obj.getClass().getMethod(methodName, p.getClass());
        return method.invoke(obj, p);
    }

    /**
     * Invokes the public no-argument method {@code methodName} on {@code obj}.
     *
     * @return the method's return value
     * @throws Exception on reflection failures or errors thrown by the invoked method
     */
    public static Object doMethodByNoParam(Object obj, String methodName) throws Exception {
        Method method = obj.getClass().getMethod(methodName);
        return method.invoke(obj);
    }

    /** Returns the name of the intercepted method. */
    public static String getMethodName(JoinPoint point) {
        MethodSignature signature = (MethodSignature) point.getSignature();
        return signature.getName();
    }

    /**
     * Returns the {@link DataFilterNative} annotation on the intercepted method, or {@code null}
     * if the method is not annotated.
     */
    public static DataFilterNative getDataFilterAnnotation(JoinPoint point) {
        MethodSignature signature = (MethodSignature) point.getSignature();
        return signature.getMethod().getAnnotation(DataFilterNative.class);
    }
}
